source("../../R/mixture_irt_md.R")
library(tidyverse)
library(multidplyr)
library(dplyr, warn.conflicts = FALSE)
library(tictoc)
library(haven)
library(kableExtra)
set.seed(123)

## Load data and baseline estimates ----

# Question metadata. The first "." in each question id is replaced by "_",
# the same normalisation applied to item_dat$names below.
mod_info <- read.csv("../../../module2010_questions.csv") %>%
  mutate(newquestion = str_replace(newquestion,
                                   pattern = "\\.",
                                   replacement = "_"))

# Item parameter file; `names` holds the item column names used to pull the
# response matrix out of est_dat throughout this script.
item_dat <- read_dta("../../../module2010_item_parameters.dta") %>%
  mutate(names = str_replace(names,
                             pattern = "\\.",
                             replacement = "_"))

# Respondent-level data with baseline weights w1/w2/w3 and position x
# (presumably the baseline 1-D estimate -- confirm against the .dta).
#   type:     the class whose weight exceeds 0.5, "Mixed" if none does.
#   moderate: x strictly between the w1-weighted 1/3 and 2/3 quantiles of x.
est_dat <- read_dta("../../../module2010_plus_estimates.dta") %>%
  mutate(
    type = case_when(
      w1 > 0.5 ~ "Downsian",
      w2 > 0.5 ~ "Conversian",
      w3 > 0.5 ~ "Inattentive",
      TRUE ~ "Mixed"
    ),
    moderate = x > Hmisc::wtd.quantile(x, probs = 1/3, weights = w1) &
      x < Hmisc::wtd.quantile(x, probs = 2/3, weights = w1)
  )


# Fit 1-D module ----
# Mixture model with a one-dimensional latent position, fit to the item
# response matrix. The three class weights, the likelihood-related outputs
# and the latent position are appended to est_dat with the suffix "_1d".
res_1d <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  em_mix_irt(ndim = 1)
est_dat <- est_dat %>%
  mutate(
    w1_1d      = res_1d$w[, 1],
    w2_1d      = res_1d$w[, 2],
    w3_1d      = res_1d$w[, 3],
    ivp_lk_1d  = res_1d$ivp$lk,
    fvp_lk_1d  = res_1d$fvp,
    irt_mlk_1d = res_1d$irt_mlk,
    irt_lk_1d  = res_1d$irt$lk,
    x_1d       = as.numeric(res_1d$irt$x)
  )

# Fit 2-D module ----
# Same mixture model with a two-dimensional latent position. Both coordinates
# are stored (x1_2d, x2_2d) alongside the weights and likelihood outputs.
res_2d <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  em_mix_irt(ndim = 2)
est_dat <- est_dat %>%
  mutate(
    w1_2d      = res_2d$w[, 1],
    w2_2d      = res_2d$w[, 2],
    w3_2d      = res_2d$w[, 3],
    ivp_lk_2d  = res_2d$ivp$lk,
    fvp_lk_2d  = res_2d$fvp,
    irt_mlk_2d = res_2d$irt_mlk,
    irt_lk_2d  = res_2d$irt$lk,
    x1_2d      = as.numeric(res_2d$irt$x[, 1]),
    x2_2d      = as.numeric(res_2d$irt$x[, 2])
  )


# Fit 1-D no-mix model ----
# One-dimensional fit with w_alpha = c(1, 0, 0) (presumably all prior weight
# on the first/IRT class, i.e. no mixing -- confirm in em_mix_irt()).
res_1d_nm <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  em_mix_irt(ndim = 1, w_alpha = c(1, 0, 0))
est_dat <- est_dat %>%
  mutate(
    w1_1d_nm      = res_1d_nm$w[, 1],
    w2_1d_nm      = res_1d_nm$w[, 2],
    w3_1d_nm      = res_1d_nm$w[, 3],
    ivp_lk_1d_nm  = res_1d_nm$ivp$lk,
    fvp_lk_1d_nm  = res_1d_nm$fvp,
    irt_mlk_1d_nm = res_1d_nm$irt_mlk,
    irt_lk_1d_nm  = res_1d_nm$irt$lk,
    x1_1d_nm      = as.numeric(res_1d_nm$irt$x[, 1])
  )


# Fit 2-D no-mix model ----
# Two-dimensional fit with w_alpha = c(1, 0, 0). Its coordinates are rotated
# and plotted further below.
res_2d_nm <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  em_mix_irt(ndim = 2, w_alpha = c(1, 0, 0))
est_dat <- est_dat %>%
  mutate(
    w1_2d_nm      = res_2d_nm$w[, 1],
    w2_2d_nm      = res_2d_nm$w[, 2],
    w3_2d_nm      = res_2d_nm$w[, 3],
    ivp_lk_2d_nm  = res_2d_nm$ivp$lk,
    fvp_lk_2d_nm  = res_2d_nm$fvp,
    irt_mlk_2d_nm = res_2d_nm$irt_mlk,
    irt_lk_2d_nm  = res_2d_nm$irt$lk,
    x1_2d_nm      = as.numeric(res_2d_nm$irt$x[, 1]),
    x2_2d_nm      = as.numeric(res_2d_nm$irt$x[, 2])
  )

# Fit NULL model ----
# w_alpha = c(0, 1, 1) puts no prior weight on the first class (hedged:
# presumably the IRT class). It is run with 100 outer iterations and a single
# IRT iteration, so only weights and the non-IRT likelihood outputs are kept.
res_nd <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  em_mix_irt(ndim = 1, w_alpha = c(0, 1, 1), iters = 100, irt_iters = 1)
est_dat <- est_dat %>%
  mutate(
    w1_nd      = res_nd$w[, 1],
    w2_nd      = res_nd$w[, 2],
    w3_nd      = res_nd$w[, 3],
    ivp_lk_nd  = res_nd$ivp$lk,
    fvp_lk_nd  = res_nd$fvp,
    irt_mlk_nd = res_nd$irt_mlk
  )


## ----cache=TRUE, results='hide'-----------------------------------------------
# Single-iteration fit with w_alpha = c(0, 1, 0), i.e. prior weight only on
# the second class. Its weights and likelihood outputs get the "_ind" suffix.
res_ind <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  em_mix_irt(ndim = 1, w_alpha = c(0, 1, 0), iters = 1)
est_dat <- est_dat %>%
  mutate(
    w1_ind      = res_ind$w[, 1],
    w2_ind      = res_ind$w[, 2],
    w3_ind      = res_ind$w[, 3],
    ivp_lk_ind  = res_ind$ivp$lk,
    fvp_lk_ind  = res_ind$fvp,
    irt_mlk_ind = res_ind$irt_mlk
  )



# Make rotation matrix for a rigid rotation of the 2-D no-mix dimensions.
# Regress the standardised 1-D position on the standardised 2-D coordinates
# and normalise the two slopes to a unit vector; that vector is the first
# column and its orthogonal complement the second. "Dimension 1" is therefore
# the direction in the 2-D space most aligned with the 1-D scale.
r <- coef(lm(I(scale(x_1d)) ~ I(scale(x1_2d_nm)) + I(scale(x2_2d_nm)),
             data = est_dat))[2:3]
r <- matrix(c(r[1], r[2], -r[2], r[1]), 2, 2) / sqrt(sum(r^2))

# Rotated 2-D no-mix positions, coloured by Pr(Downsian) from the 1-D model.
gg <- ggplot(est_dat,
             aes(x = scale(x1_2d_nm) * r[1, 1] + scale(x2_2d_nm) * r[2, 1],
                 y = scale(x1_2d_nm) * r[1, 2] + scale(x2_2d_nm) * r[2, 2],
                 col = w1_1d)) +
  geom_point(size = 0.5, alpha = 0.5) +
  coord_equal() +
  theme_minimal() +
  theme(plot.margin = grid::unit(c(0, 0, 0, 0), "mm")) +
  xlab("Dimension 1") +
  ylab("Dimension 2") +
  scale_color_distiller(palette = "Spectral",
                        breaks = c(0.25, 0.5, 0.75)) +
  guides(color = guide_colorbar(title = "Pr(Downsian)",
                                title.vjust = 0.8,
                                label = TRUE,
                                barwidth = 0.4))

# Print explicitly. A bare `gg` is auto-printed only at the REPL/Rscript top
# level, not under source(), which would leave the PDF empty.
pdf("two_dim_nomix.pdf", width = 5, height = 3.5)
print(gg)
dev.off()

# Plot estimates of Pr(Downsian) across 1-D & 2-D models ----
# The plane is split at 0.5 on both axes. Each quadrant is labelled with the
# rounded percentage of all respondents lying strictly inside it; points at
# exactly 0.5 are in no quadrant.
quadrant_labels <- tibble(
  x = c(0.25, 0.75, 0.25, 0.75),
  y = c(0.25, 0.25, 0.75, 0.75),
  label = round(100 * c(
    sum(est_dat$w1_1d < 0.5 & est_dat$w1_2d < 0.5),
    sum(est_dat$w1_1d > 0.5 & est_dat$w1_2d < 0.5),
    sum(est_dat$w1_1d < 0.5 & est_dat$w1_2d > 0.5),
    sum(est_dat$w1_1d > 0.5 & est_dat$w1_2d > 0.5)
  ) / NROW(est_dat))
)

gg <- ggplot(est_dat, aes(w1_1d, w1_2d)) +
  geom_point(size = 0.5, alpha = 0.2) +
  coord_equal() +
  theme_minimal() +
  xlab("Pr(Downsian) in 1-D model") +
  ylab("Pr(Downsian) in 2-D model") +
  # One text layer over its own 4-row table. A geom_text() with constant
  # x/y/label that inherits est_dat draws the label once per respondent,
  # stacking NROW(est_dat) copies on top of each other.
  geom_text(data = quadrant_labels,
            aes(x = x, y = y, label = label),
            inherit.aes = FALSE) +
  geom_hline(yintercept = 0.5) +
  geom_vline(xintercept = 0.5)

# Explicit print() so the PDF is not empty when this script is source()d.
pdf("weights_1d_vs_2d.pdf", width = 4, height = 3)
print(gg)
dev.off()


# Simulate 1-D data from the model ----
# 1. Draw a class (1, 2 or 3) per respondent from its fitted weights.
# 2. Build the respondent x item response probability that class implies:
#    the IRT curve, the item-level `ivp$ivp` probability, or a coin flip.
# 3. Draw binary responses and copy the observed missingness pattern.
# sample() errors on negative probabilities, so clip tiny negative weights.
# NOTE: this clips res_1d in place; the clipped fit is what cv_perplexity()
# receives later in the script.
res_1d$w[res_1d$w < 0] <- 0
w <- apply(res_1d$w, 1, function(p) sample(1:3, 1, prob = p))
phat_irt <- plogis(cbind(1, res_1d$irt$x) %*% rbind(res_1d$irt$a, res_1d$irt$b))
o_phat <- dim(phat_irt); n <- o_phat[1]; k <- o_phat[2]
phat_ivp <- matrix(res_1d$ivp$ivp, ncol = k, nrow = n, byrow = TRUE)
# `w` has length n and recycles down each column, so row i uses respondent
# i's drawn class.
phat <- (w == 1) * phat_irt + (w == 2) * phat_ivp + (w == 3) * 0.5
# Bernoulli(phat) draw: a uniform below phat gives a 1 with probability phat.
# This assumes phat is Pr(response = 1), as plogis() of the linear predictor
# indicates -- confirm against em_mix_irt().
sim_dat <- (matrix(runif(n * k), n, k) < phat) + 0
na_mask <- est_dat %>% select(item_dat$names) %>% as.matrix() %>% is.na()
sim_dat[na_mask] <- NA

# Fit the 2-D no-mix model to the simulated (calibrated) 1-D data ----
res_2d_nm_sim <- em_mix_irt(sim_dat, ndim = 2, w_alpha = c(1, 0, 0))

# Plot estimates from the simulated data. The rotation mirrors the one used
# for the real-data 2-D no-mix plot: align Dimension 1 with the 1-D scale.
est_dat_wsim <- est_dat %>%
  mutate(x1_2d_nm_sim = res_2d_nm_sim$irt$x[, 1],
         x2_2d_nm_sim = res_2d_nm_sim$irt$x[, 2])
r <- coef(lm(I(scale(x_1d)) ~ I(scale(x1_2d_nm_sim)) + I(scale(x2_2d_nm_sim)),
             data = est_dat_wsim))[2:3]
r <- matrix(c(r[1], r[2], -r[2], r[1]), 2, 2) / sqrt(sum(r^2))

# The second dimension is negated (its orientation is arbitrary). xlim() drops
# points outside [-3, 4.5], and ggplot2 warns about the removed rows.
gg <- ggplot(est_dat_wsim,
             aes(x = scale(x1_2d_nm_sim) * r[1, 1] + scale(x2_2d_nm_sim) * r[2, 1],
                 y = -(scale(x1_2d_nm_sim) * r[1, 2] + scale(x2_2d_nm_sim) * r[2, 2]),
                 col = w1_1d)) +
  geom_point(size = 0.5, alpha = 0.5) +
  coord_equal() +
  theme_minimal() +
  xlab("Dimension 1") +
  ylab("Dimension 2") +
  xlim(-3, 4.5) +
  scale_color_distiller(palette = "Spectral",
                        breaks = c(0.25, 0.5, 0.75)) +
  guides(color = guide_colorbar(title = "Pr(Downsian)",
                                title.vjust = 0.8,
                                label = TRUE,
                                barwidth = 0.4))

# Explicit print() so the PDF is not empty when this script is source()d.
pdf("two_dim_nomix_sim.pdf", width = 5, height = 3.5)
print(gg)
dev.off()

# Calculate perplexity ----
# Baseline fit res_0d: w_alpha = c(0, 1, 0) with iters = 1 and irt_iters = 1.
# It is scored with cv_perplexity() alongside the 1-D mix, 2-D mix and 2-D
# no-mix fits from above.
res_0d <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  em_mix_irt(ndim = 1, w_alpha = c(0, 1, 0), irt_iters = 1, iters = 1)

# NOTE(review): `irt_iter` below vs `irt_iters` in em_mix_irt(). Confirm that
# cv_perplexity() declares `irt_iter`; otherwise this relies on partial
# argument matching or is silently absorbed by `...`.
perplex_0d <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  cv_perplexity(fit = res_0d, iters = 1, irt_iter = 1)
perplex_1d <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  cv_perplexity(fit = res_1d)
perplex_2d <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  cv_perplexity(fit = res_2d)
perplex_2d_nm <- est_dat %>%
  select(all_of(item_dat$names)) %>%
  as.matrix() %>%
  cv_perplexity(fit = res_2d_nm)

# Make perplexity table ----
# One row per model, in the listed order: the in-sample log-likelihood and the
# cross-validated perplexity, rendered as a booktabs LaTeX table.
perp_table <- bind_rows(
  perplex_0d %>% mutate(model = "Null", i = 1),
  perplex_1d %>% mutate(model = "1-D (mix.)", i = 2),
  perplex_2d_nm %>% mutate(model = "2-D (no mix.)", i = 3),
  perplex_2d %>% mutate(model = "2-D (mix.)", i = 4)
) %>%
  group_by(model) %>%
  summarize(i = i[1],
            perp = perp[estimate == "CV"],
            loglik = loglik[estimate == "In sample"]) %>%
  arrange(i) %>%
  ungroup() %>%
  select(model, loglik, perp) %>%
  kbl(digits = c(0, 0, 2),
      format = "latex",
      booktabs = TRUE,
      col.names = c("Model", "Log-likelihood", "Perplexity")) %>%
  kable_styling()

# "\n" belongs to paste0(). Previously it was passed into as.character(),
# which ignores it, so no trailing newline was printed.
cat(paste0(as.character(perp_table), "\n"))

# Strip the \begin{table} and \end{table} lines (plus their line breaks) so
# the written file holds only the table body.
write_file(as.character(perp_table) %>%
             str_replace_all("\\\\[begind]+\\{table\\}\\s*(\\n|$)", ""),
           "cces2010module_perplexity.tex")

